library(tidyverse); library(parsnip);
library(recipes); library(rsample); library(yardstick); library(furrr)

# Run furrr calls on background R sessions, leaving one core free. The worker
# count is clamped to at least 1 so that a single-core host (where
# `availableCores() - 1` would be 0) still gets a valid plan.
plan(multisession, workers = max(1, availableCores() - 1))

#' Fit an ordinary least squares model of `dv` on the given covariates
#'
#' @param data Data frame holding `dv` and every column named in `covars`.
#' @param covars Character vector of covariate names. When it has length 0
#'   an intercept-only model (`dv ~ 1`) is fitted.
#' @param dv Name of the outcome column, as a string.
#' @return The fitted parsnip model (`linear_reg()` with the "lm" engine).
fit_lm <- function(data, covars, dv){
  # `ifelse()` returns a result shaped like its (length-1) test, which would
  # keep only the first covariate; a scalar `if` keeps the whole vector.
  rhs_terms <- if (length(covars) == 0) "1" else covars
  formula <- as.formula(paste(dv, " ~ ", paste(rhs_terms, collapse = "+")))
  linear_reg() %>%
    set_engine("lm") %>%
    fit(formula, data = data)
}

#' Cross-validated prediction error of an OLS fit on a covariate set
#'
#' Draws a 5-fold split of `data`, fits `fit_lm()` on each analysis set,
#' predicts the matching assessment set and averages the per-fold RMSE.
#' Despite the function name, the returned value is a mean RMSE, not an MSE.
#'
#' @param data Data frame containing `dv` and the columns named in `coefs`.
#' @param coefs Character vector of covariate names passed on to `fit_lm()`.
#' @param dv Name of the outcome column, as a string.
#' @return A single numeric value: the mean of the five fold-level RMSEs.
calc_cv_mse <- function(data, coefs, dv){
  folds <- vfold_cv(data = data, v = 5)

  fold_rmse <- map_dbl(folds$splits, function(split) {
    held_out <- assessment(split)
    fold_model <- fit_lm(analysis(split), coefs, dv)
    # Predictions come back in a `.pred` column; bind them next to the truth.
    scored <- bind_cols(predict(fold_model, held_out), held_out)
    rmse(scored, !! sym(dv), .pred)$.estimate
  })

  mean(fold_rmse)
}


#' Select covariates via per-arm lasso paths and cross-validated OLS error
#'
#' Splits `data` into treated (`treat == 1`) and control (`treat == 0`) rows
#' and fits a glmnet lasso path of `dv` on all other columns within each arm.
#' For every lambda on the path, the terms with non-zero coefficients are
#' scored with `calc_cv_mse()`; the best-scoring set of each arm is kept and
#' the union of both sets is returned.
#'
#' @param data Data frame containing `dv`, `treat` and the candidate columns.
#' @param dv Name of the outcome column, as a string.
#' @param covars Not used by the body: the lasso formula is `dv ~ .`, so all
#'   columns of `data` other than `dv` are candidates.
#'   NOTE(review): presumably the candidates were meant to be restricted to
#'   `covars` -- confirm with callers before changing.
#' @param treat Name of the treatment indicator column (compared to 1 and 0).
#' @return Character vector of unique term names selected in either arm;
#'   `character(0)` when no term is selected.
lasso_covars <- function(data, dv, covars, treat){
  lasso_formula <- as.formula(paste(dv, " ~ ."))
  # NOTE(review): some parsnip releases require `penalty` to be set for the
  # glmnet engine -- confirm against the installed version.
  lasso_mod <- linear_reg(mixture = 1) %>%
    set_engine("glmnet")

  # Best covariate set (lowest cross-validated error) for one treatment arm.
  select_arm_covars <- function(arm_data) {
    arm_fit <- lasso_mod %>%
      fit(lasso_formula, data = arm_data)
    # One column of coefficients per lambda on the regularisation path.
    path_beta <- as.matrix(arm_fit$fit$beta)

    # `lapply()` always yields a list (unlike `apply()`, which simplifies
    # equal-length results), and `!= 0` keeps negative coefficients too.
    candidate_sets <- lapply(
      seq_len(ncol(path_beta)),
      function(j) rownames(path_beta)[path_beta[, j] != 0]
    )

    # `calc_cv_mse()` draws random folds, so request parallel-safe seeds to
    # make results reproducible under `set.seed()`.
    cv_error <- future_map_dbl(
      candidate_sets,
      ~ calc_cv_mse(data = arm_data, coefs = .x, dv = dv),
      .options = furrr_options(seed = TRUE),
      .progress = FALSE
    )

    # `which.min()` ignores NA and resolves ties to the first (largest
    # lambda, sparsest) candidate instead of relying on float equality.
    best <- which.min(cv_error)
    if (length(best) == 0) {
      return(character(0))
    }
    candidate_sets[[best]]
  }

  lasso_treat_data <- filter(data, !! sym(treat) == 1)
  lasso_contr_data <- filter(data, !! sym(treat) == 0)

  union(
    select_arm_covars(lasso_treat_data),
    select_arm_covars(lasso_contr_data)
  )
}
